{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Quantum State Discrimination\n",
    "\n",
    "\n",
    "<em> Copyright (c) 2021 Institute for Quantum Computing, Baidu Inc. All Rights Reserved. </em>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Overview\n",
    "\n",
    "Quantum state discrimination (QSD) [1-2] is a fundamental question in quantum communication, quantum computation, and quantum cryptography. In this tutorial, we will explain how to discriminate two orthogonal bipartite pure states $\\lvert\\psi\\rangle$ and $\\lvert\\phi\\rangle$, which satisfies $\\langle\\psi\\lvert\\phi\\rangle=0$, under the constraint of Local Operations and Classical Communication (LOCC). We refer all the theoretical details to the original paper [3]."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## QSD Protocol\n",
    "\n",
    "Firstly, we want to make the problem definition clear. Consider two spatially separated parties $A$ (Alice) and $B$ (Bob) share a given two-qubit system, the system state is $\\lvert\\varphi\\rangle$ previously distributed by another party $C$ (Charlie). Alice and Bob were only notified that $\\lvert\\varphi\\rangle$ is either $\\lvert\\psi\\rangle$ or $\\lvert\\phi\\rangle$ (both are pure states), satisfying $\\langle\\psi\\lvert\\phi\\rangle=0$. Then, Charlie provides many copies of $\\lvert\\psi\\rangle$ and $\\lvert\\phi\\rangle$ to them, and he asks Alice and Bob to cooperate with each other to figure out which state they are actually sharing.\n",
    "\n",
    "\n",
    "Solving this problem under our LOCCNet framework is trivial. As always, let's start with the simplest one-round LOCC protocol with a QNN architecture shown in Figure 1. Then, the difficult lies in the design of an appropriate loss function $L$. Since we choose to let both parties to measure their subsystem, there will be four possible measurement results $m_Am_B\\in\\{00, 01, 10, 11\\}$. To distinguish $\\lvert\\psi\\rangle$ and $\\lvert\\phi\\rangle$, we will label the former state with measurement results $m_Am_B\\in\\{00, 10\\}$ and the latter with $m_Am_B\\in\\{01, 11\\}$. This step can be understood as adding labels to the data in supervised learning. With these labels, we can define the loss function as the probability of guessing wrong label,\n",
    "\n",
    "$$\n",
    "L = p_{\\lvert\\psi\\rangle\\_01}+p_{\\lvert\\psi\\rangle\\_11}+p_{\\lvert\\phi\\rangle\\_10}+p_{\\lvert\\phi\\rangle\\_00}.\n",
    "\\tag{1}\n",
    "$$\n",
    "\n",
    "where $p_{\\lvert\\psi\\rangle\\_01}$ stands for the probability of measuring 01 when the input state is $\\lvert\\psi\\rangle$. Then we can begin the training stage to minimize the loss function.\n",
    "\n",
    "\n",
    "![qsd](figures/discrimination-fig-circuit.png \"Figure 1: Schematic diagram of state discrimination with LOCCNet.\")\n",
    "<div style=\"text-align:center\">Figure 1: Schematic diagram of state discrimination with LOCCNet. </div>\n",
    "\n",
    "\n",
    "We summarize the workflow below:\n",
    "\n",
    "\n",
    "1. Alice and Bob share a two-qubit system, which state is either $\\lvert\\psi\\rangle$ or $\\lvert\\phi\\rangle$. \n",
    "2. Alice operates a general rotation gate $U_A$ on her qubit.\n",
    "3. Alice measures her qubit on the computational basis, and the result $m_A\\in \\{0, 1\\}$. Then, she communicates with Bob about the measurement result through a classical channel.\n",
    "4.  Bob operates different gates on his qubit depending on Alice's measurement result. If, $m_A=0$ Bob acts $U_{B0}$ on his qubit; If $m_A = 1$, then Bob acts $U_{B1}$. Then, Bob measures his qubit and obtain $m_B \\in \\{0,1\\}$. **Note**: Both $U_{B0}$ and $U_{B1}$ are universal single-qubit gate `u3()`.\n",
    "5. Calculate the loss function $L = p_{\\lvert\\psi\\rangle\\_01}+p_{\\lvert\\psi\\rangle\\_11}+ p_{\\lvert\\phi\\rangle\\_10}+ p_{\\lvert\\phi\\rangle\\_00}$, and use gradient-based optimization methods to minimize it.\n",
    "6. Repeat 1-5 until the loss function converges.\n",
    "7. Input the pre-shared state $\\lvert\\varphi\\rangle$ to make a decision and compare with Charlie's answer."
   ]
  },
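  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The loss in Eq. (1) can also be evaluated by hand for a fixed choice of gates. The following is a minimal NumPy sketch, not the LOCCNet implementation: the helper names `u3` and `locc_loss` are hypothetical, the `u3` matrix uses a common $(\\theta, \\phi, \\lambda)$ parameterization that is assumed here to match the gate used in the protocol, and the two Bell states serve only as an illustrative input pair.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def u3(theta, phi, lam):\n",
    "    # Assumed single-qubit gate parameterization (hypothetical helper)\n",
    "    return np.array([\n",
    "        [np.cos(theta / 2), -np.exp(1j * lam) * np.sin(theta / 2)],\n",
    "        [np.exp(1j * phi) * np.sin(theta / 2),\n",
    "         np.exp(1j * (phi + lam)) * np.cos(theta / 2)],\n",
    "    ])\n",
    "\n",
    "\n",
    "def locc_loss(psi, phi, U_A, U_B0, U_B1):\n",
    "    # Eq. (1): psi is labeled by m_B = 0 and phi by m_B = 1\n",
    "    loss = 0.0\n",
    "    for state, wrong_mb in ((psi, 1), (phi, 0)):\n",
    "        # Alice's gate acts on the first qubit; rows of amp are indexed by m_A\n",
    "        amp = (np.kron(U_A, np.eye(2)) @ state).reshape(2, 2)\n",
    "        for m_a, U_B in enumerate((U_B0, U_B1)):\n",
    "            # Bob's unnormalized state after Alice reports m_A\n",
    "            bob = U_B @ amp[m_a]\n",
    "            # Joint probability of m_A together with the wrong m_B\n",
    "            loss += abs(bob[wrong_mb]) ** 2\n",
    "    return loss\n",
    "\n",
    "\n",
    "# Illustrative input pair: two orthogonal Bell states\n",
    "bell_plus = np.array([1, 0, 0, 1]) / np.sqrt(2)\n",
    "bell_minus = np.array([1, 0, 0, -1]) / np.sqrt(2)\n",
    "\n",
    "hadamard = u3(np.pi / 2, 0, np.pi)\n",
    "loss_good = locc_loss(bell_plus, bell_minus, hadamard, hadamard, u3(np.pi / 2, 0, 0))\n",
    "loss_identity = locc_loss(bell_plus, bell_minus, np.eye(2), np.eye(2), np.eye(2))\n",
    "print(loss_good, loss_identity)\n",
    "```\n",
    "\n",
    "For this particular pair, Hadamard-type gates drive the loss to zero (up to floating-point error), while identity gates give a loss of 1. For a randomly generated pair, suitable angles have to be found by training."
   ]
  },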
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Simulation with Paddle Quantum \n",
    "\n",
    "First, import relevant packages."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-03-09T04:11:04.474304Z",
     "start_time": "2021-03-09T04:11:01.071347Z"
    }
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from scipy.stats import unitary_group\n",
    "import paddle\n",
    "from paddle_quantum.locc import LoccNet"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Randomly generate two orthogonal pure states $\\lvert\\psi\\rangle$ and $\\lvert\\phi\\rangle$ by Charlie."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-03-09T04:11:04.498238Z",
     "start_time": "2021-03-09T04:11:04.478356Z"
    }
   },
   "outputs": [],
   "source": [
    "def states_orthogonal_random(n, num=2):\n",
    "    \n",
    "    # Randomly generate two orthogonal states\n",
    "    assert num <= 2 ** n, \"return too many orthognal states\"\n",
    "    U = unitary_group.rvs(2 ** n)\n",
    "    return_list = []\n",
    "    for i in range(num):\n",
    "        return_list.append(U[i])\n",
    "    return return_list"
   ]
  },
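  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The function above relies on the fact that the rows of a unitary matrix are orthonormal. As a quick sanity check, the sketch below builds a unitary with NumPy alone, using a QR decomposition as a stand-in for `unitary_group.rvs` (an assumption made to keep the example dependency-free; without a phase correction the result need not be Haar-distributed, which does not matter for checking orthonormality). The variable names are illustrative.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "dim = 4  # Hilbert-space dimension of two qubits\n",
    "\n",
    "# QR decomposition of a complex Gaussian matrix yields a unitary factor Q\n",
    "gaussian = rng.normal(size=(dim, dim)) + 1j * rng.normal(size=(dim, dim))\n",
    "Q, _ = np.linalg.qr(gaussian)\n",
    "\n",
    "# Take two rows as the candidate states\n",
    "psi, phi = Q[0], Q[1]\n",
    "overlap = np.vdot(psi, phi)  # inner product <psi|phi>\n",
    "norm_psi = np.linalg.norm(psi)\n",
    "print(abs(overlap), norm_psi)\n",
    "```\n",
    "\n",
    "The overlap should vanish and the norm should equal one, both up to floating-point error."
   ]
  },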
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Below is the main part of our LOCC protocol:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-03-09T04:11:04.551149Z",
     "start_time": "2021-03-09T04:11:04.506818Z"
    }
   },
   "outputs": [],
   "source": [
    "class Net(LoccNet):\n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "        \n",
    "        # Add the first party Alice \n",
    "        # The first parameter 1 stands for how many qubits A holds\n",
    "        # The second parameter records the name of this qubit holder\n",
    "        self.add_new_party(1, party_name='Alice')\n",
    "        \n",
    "        # Add the first party Bob \n",
    "        # The first parameter 1 stands for how many qubits B holds\n",
    "        # The second parameter records the name of this qubit holder\n",
    "        self.add_new_party(1, party_name='Bob')\n",
    "        \n",
    "        # Initialize Alice's parameter\n",
    "        self.theta1 = self.create_parameter(shape=[3],\n",
    "                                            default_initializer=paddle.nn.initializer.Uniform(low=0.0, high=2 * np.pi),\n",
    "                                            dtype=\"float64\")\n",
    "        # Initialize Bob's parameter\n",
    "        # Bob has to prepare two circuits according Alice's measurement result \n",
    "        self.theta2 = self.create_parameter(shape=[3],\n",
    "                                            default_initializer=paddle.nn.initializer.Uniform(low=0.0, high=2 * np.pi),\n",
    "                                            dtype=\"float64\")\n",
    "        self.theta3 = self.create_parameter(shape=[3],\n",
    "                                            default_initializer=paddle.nn.initializer.Uniform(low=0.0, high=2 * np.pi),\n",
    "                                            dtype=\"float64\")\n",
    "        \n",
    "        # Rewrite the input states into density martices\n",
    "        _states = states_orthogonal_random(2)\n",
    "        _states = [paddle.to_tensor(np.outer(init_state, init_state.conjugate())) for init_state in _states]\n",
    "        \n",
    "        # Initialize the system by distributing states\n",
    "        self.set_init_state(_states[0], [('Alice', 0), ('Bob', 0)])\n",
    "        self.psi = self.init_status\n",
    "        self.phi = self.reset_state(self.init_status, _states[1], [('Alice', 0), ('Bob', 0)])\n",
    "\n",
    "    def A_circuit(self, theta, state, res):\n",
    "        # Alice's local operations\n",
    "        cir = self.create_ansatz('Alice')\n",
    "        # Add single-qubit universal gate\n",
    "        cir.u3(*theta, 0)\n",
    "        # Run circuit\n",
    "        after_state = cir.run(state)\n",
    "        # Measure the circuit and record the measurement results \n",
    "        after_state = self.measure(status=after_state, which_qubits=('Alice', 0), results_desired=res)\n",
    "        return after_state\n",
    "\n",
    "    def B_circuit(self, theta, state, res):\n",
    "        # Bob's local operations\n",
    "        cir = self.create_ansatz('Bob')\n",
    "        # Add single-qubit universal gate\n",
    "        cir.u3(*theta, 0)\n",
    "        # Run circuit\n",
    "        after_state = cir.run(state)\n",
    "        # Measure the circuit and record the measurement results \n",
    "        after_state = self.measure(status=after_state, which_qubits=('Bob', 0), results_desired=res)\n",
    "        return after_state\n",
    "    \n",
    "    def forward(self):\n",
    "        # Training steps\n",
    "        # Quantum state after Alice's operation\n",
    "        phi = self.A_circuit(theta=self.theta1, state=self.phi, res=['0', '1'])\n",
    "        psi = self.A_circuit(theta=self.theta1, state=self.psi, res=['0', '1'])\n",
    "\n",
    "        # Calculate the loss function\n",
    "        loss = 0\n",
    "        for each_phi in phi:\n",
    "            if each_phi.measured_result == '0':\n",
    "                # maybe we should call it phi_1 which means predicting the phi to psi\n",
    "                phi_1 = self.B_circuit(self.theta2, state=each_phi, res='1')\n",
    "                loss += phi_1.prob\n",
    "            elif each_phi.measured_result == '1':\n",
    "                psi_1 = self.B_circuit(self.theta3, state=each_phi, res='1')\n",
    "                loss += psi_1.prob\n",
    "\n",
    "        for each_psi in psi:\n",
    "            if each_psi.measured_result == '0':\n",
    "                phi_0 = self.B_circuit(self.theta2, state=each_psi, res='0')\n",
    "                loss += phi_0.prob\n",
    "            elif each_psi.measured_result == '1':\n",
    "                psi_0 = self.B_circuit(self.theta3, state=each_psi, res='0')\n",
    "                loss += psi_0.prob\n",
    "\n",
    "        return loss\n",
    "\n",
    "    def evaluate(self):\n",
    "        # Test step\n",
    "        choice = np.random.choice(['phi', 'psi'])\n",
    "        if choice == 'phi':\n",
    "            self.status = self.phi\n",
    "        else:\n",
    "            self.status = self.psi\n",
    "        print('Charlie chooses the state', choice)\n",
    "\n",
    "        # Alice's operations\n",
    "        status = self.A_circuit(theta=self.theta1, state=self.status, res=['0', '1'])\n",
    "        # Bob's operations \n",
    "        result_0 = list()\n",
    "        result_1 = list()\n",
    "        for each_status in status:\n",
    "            if each_status.measured_result == '0':\n",
    "                phi = self.B_circuit(theta=self.theta2, state=each_status, res=['0', '1'])\n",
    "                result_0.append(phi[0].prob.numpy())\n",
    "                result_0.append(phi[1].prob.numpy())\n",
    "\n",
    "            elif each_status.measured_result == '1':\n",
    "                psi = self.B_circuit(theta=self.theta3, state=each_status, res=['0', '1'])\n",
    "                result_1.append(psi[0].prob.numpy())\n",
    "                result_1.append(psi[1].prob.numpy())\n",
    "        print(\"The probability that Alice and Bob recognize it as phi:\", result_0[0][0] + result_1[0][0])\n",
    "        print(\"The probability that Alice and Bob recognize it as psi:\", result_0[1][0] + result_1[1][0])\n",
    "        "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Train the QNN parameters, and Charlie randomly select one of the two orthogonal states $\\lvert\\psi\\rangle$ and $\\lvert\\phi\\rangle$ and see whether Alice and Bob can distinguish it correctly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-03-09T04:11:19.312821Z",
     "start_time": "2021-03-09T04:11:14.759835Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "itr 0: 0.8761168230331675\n",
      "itr 10: 0.20833627335270305\n",
      "itr 20: 0.09979993875586758\n",
      "itr 30: 0.05990778041973435\n",
      "itr 40: 0.02806489116318598\n",
      "itr 50: 0.026191291836946892\n",
      "itr 60: 0.02245618597383097\n",
      "itr 70: 0.01903560177140634\n",
      "itr 80: 0.016262801555360196\n",
      "itr 90: 0.011920717566034985\n",
      "Minimum loss: 0.0077142257760679424\n",
      "======================== test stage ===============================\n",
      "Charlie chooses the state psi\n",
      "The probability that Alice and Bob recognize it as phi: 0.003188464207077198\n",
      "The probability that Alice and Bob recognize it as psi: 0.9968115357929229\n",
      "Charlie chooses the state phi\n",
      "The probability that Alice and Bob recognize it as phi: 0.99590992116251\n",
      "The probability that Alice and Bob recognize it as psi: 0.004090078837489469\n"
     ]
    }
   ],
   "source": [
    "ITR = 100  # Set the number of training iterations\n",
    "LR = 0.1   # Set learning rate\n",
    "SEED = 999 # Fix randome seed for parameters in PQC\n",
    "np.random.seed(SEED)\n",
    "paddle.seed(SEED)\n",
    "\n",
    "net = Net()\n",
    "opt = paddle.optimizer.Adam(learning_rate=LR, parameters=net.parameters())\n",
    "loss_list = list()\n",
    "# Train the LOCC net for ITR iterations by gradient descent\n",
    "for itr in range(ITR):\n",
    "    loss = net()\n",
    "    loss.backward()\n",
    "    opt.minimize(loss)\n",
    "    opt.clear_grad()\n",
    "    loss_list.append(loss.numpy()[0])\n",
    "    if itr % 10 == 0:\n",
    "        print(\"itr \" + str(itr) + \":\", loss.numpy()[0])\n",
    "print(\"Minimum loss:\", loss_list[-1])\n",
    "\n",
    "print(\"======================== test stage ===============================\")\n",
    "np.random.seed(10)\n",
    "net.evaluate()\n",
    "np.random.seed(6)\n",
    "net.evaluate()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Conclusion\n",
    "\n",
    "It can be seen from the simulation results that the trained quantum circuit can distinguish two orthogonal quantum states almost perfectly with an accuracy $>99.9\\%$. There is an interesting question that can we generalize this discrimination scheme by adding more states to the category."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "## References\n",
    "\n",
    "[1] Barnett, Stephen M., and Sarah Croke. \"Quantum state discrimination.\" [Advances in Optics and Photonics 1.2 (2009): 238-278.](https://www.osapublishing.org/abstract.cfm?id=176580)\n",
    "\n",
    "[2] Chefles, Anthony. \"Quantum state discrimination.\" [Contemporary Physics 41.6 (2000): 401-424.](https://arxiv.org/abs/quant-ph/0010114)\n",
    "\n",
    "[3] Walgate, Jonathan, et al. \"Local distinguishability of multipartite orthogonal quantum states.\" [Physical Review Letters 85.23 (2000): 4972.](https://journals.aps.org/prl/abstract/10.1103/PhysRevLett.85.4972)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
